import torch.nn as nn
import torch
from .basiclinear import BasicLinear
from ..op import separable_linear_bmm, separable_linear_custom


class SeparableLinear(BasicLinear):
    """Linear layer parameterized by two stacks of block-diagonal factors.

    The layer owns two parameters, ``blkdiag1`` and ``blkdiag2``, each a stack
    of ``nblocks`` dense blocks. The forward pass hands both to
    ``separable_linear_custom``. The smaller of the per-block input/output
    sizes is used as the intermediate width between the two factors.

    NOTE(review): ``_init_weights``, ``forward_guide_layer``, ``config``,
    ``gamma``, ``gamma_decay``, ``guide_linear``, ``bias`` and
    ``training_config`` are not defined in this class; presumably they are
    provided by ``BasicLinear`` -- confirm there.
    """

    def __init__(
        self,
        in_features,
        out_features,
        bias,
        return_bias,
        config,
        init_config,
        device="cuda",
    ):
        """Allocate both block-diagonal factors and run the initializers.

        Args:
            in_features: Input width; must be divisible by ``config["nblocks"]``.
            out_features: Output width; must be divisible by
                ``config["nblocks"]``.
            bias: Forwarded unchanged to ``BasicLinear.__init__``.
            return_bias: Forwarded unchanged to ``BasicLinear.__init__``.
            config: Layer configuration; ``config["nblocks"]`` is read here.
            init_config: Forwarded unchanged to ``BasicLinear.__init__``.
            device: Device on which the factor tensors are allocated.

        Raises:
            AssertionError: If either feature count is not a multiple of the
                number of blocks.
        """
        super().__init__(
            in_features, out_features, bias, return_bias, config, init_config, device
        )
        self.nblocks = config["nblocks"]
        assert self.in_features % self.nblocks == 0, (
            f"in_features ({self.in_features}) must be divisible by "
            f"nblocks ({self.nblocks})"
        )
        assert self.out_features % self.nblocks == 0, (
            f"out_features ({self.out_features}) must be divisible by "
            f"nblocks ({self.nblocks})"
        )

        in_blksz = self.in_features // self.nblocks
        out_blksz = self.out_features // self.nblocks
        # The intermediate width is the narrower side, so the first factor is
        # square when in < out and the second factor is square otherwise.
        mid_blksz = min(in_blksz, out_blksz)

        self.blkdiag1 = nn.Parameter(
            torch.empty(self.nblocks, mid_blksz, in_blksz, device=device)
        )
        self.blkdiag2 = nn.Parameter(
            torch.empty(self.nblocks, out_blksz, mid_blksz, device=device)
        )
        self._init_weights()
        self.post_init()

    @staticmethod
    def _dense_block_diag(blocks):
        """Return the dense block-diagonal matrix built from a stack of blocks."""
        return torch.block_diag(*torch.unbind(blocks.data, dim=0))

    def get_weights(
        self,
    ):
        """Return the two factor parameters as ``[blkdiag1, blkdiag2]``."""
        return [self.blkdiag1, self.blkdiag2]

    @torch.no_grad()
    def post_init(
        self,
    ):
        """Optionally orthogonalize every block, then seed the guide weight.

        When ``config.init.post_init`` equals ``"ortho"``, each block is
        replaced by ``U @ Vh`` from its reduced SVD, i.e. its singular values
        are discarded. If the module has a ``guide_linear`` attribute, its data
        is overwritten with the dense product of the two block-diagonal
        factors.
        """
        if self.config.init.post_init == "ortho":
            for weight in (self.blkdiag1, self.blkdiag2):
                for i in range(self.nblocks):
                    U, _, Vh = torch.linalg.svd(weight.data[i], full_matrices=False)
                    weight.data[i] = torch.mm(U, Vh)

        # init guide linear
        if hasattr(self, "guide_linear"):
            self.guide_linear.data = torch.mm(
                self._dense_block_diag(self.blkdiag2),
                self._dense_block_diag(self.blkdiag1),
            )

    @torch.no_grad()
    def old_frobgrad(self, wd=0.0):
        """Apply one in-place decay step to the product of the two factors.

        Deprecated: decaying the product showed no clear benefit and costs
        extra computation.

        With ``W = U @ P.T @ Vh`` (``U`` / ``Vh`` the dense block-diagonal forms
        of ``blkdiag2`` / ``blkdiag1`` and ``P`` a permutation matrix), the
        diagonal blocks of ``P @ U.T @ W`` and ``W @ Vh.T @ P`` are scaled by
        ``wd`` and subtracted from ``blkdiag1`` and ``blkdiag2`` respectively.
        When ``config.guide`` and ``gamma`` are both truthy, ``guide_linear``
        is decayed as well and ``gamma`` is reduced by ``gamma_decay``, clamped
        at zero.

        Args:
            wd: Decay coefficient; a falsy value makes the call a no-op.
        """
        if not wd:
            return

        # Follow the parameters' device/dtype instead of assuming the current
        # CUDA device and float32.
        device = self.blkdiag1.device

        if self.config.guide and self.gamma:
            self.guide_linear.data -= wd * self.gamma * self.guide_linear.data
            wd *= 1.0 - self.gamma
            self.gamma -= self.gamma_decay
            self.gamma = max(self.gamma, torch.tensor(0.0, device=device))

        Vh = self._dense_block_diag(self.blkdiag1)
        U = self._dense_block_diag(self.blkdiag2)
        h = Vh.shape[0]
        # Read a row-major (h // nblocks, nblocks) index grid column-first.
        perm = (
            torch.arange(h, device=device)
            .reshape(h // self.nblocks, self.nblocks)
            .transpose(0, 1)
            .reshape(-1)
        )
        P = torch.eye(h, dtype=Vh.dtype, device=device)[perm]

        # torch.linalg.multi_dot is the documented replacement for the
        # deprecated torch.chain_matmul.
        product = torch.linalg.multi_dot([U, P.T, Vh])
        grad_u = torch.linalg.multi_dot([product, Vh.T, P])
        grad_vh = torch.linalg.multi_dot([P, U.T, product])

        # Only the diagonal blocks correspond to stored parameters.
        for weight, grad in ((self.blkdiag1, grad_vh), (self.blkdiag2, grad_u)):
            count, rows, cols = weight.shape
            for i in range(count):
                weight.data[i] -= (
                    wd * grad[i * rows : (i + 1) * rows, i * cols : (i + 1) * cols]
                )

    def forward(self, input):
        """Apply both factors to ``input`` and pass the result to the guide layer.

        Returns whatever ``forward_guide_layer(input, out)`` returns, where
        ``out`` is the output of ``separable_linear_custom``.
        """
        out = separable_linear_custom(input, self.blkdiag1, self.blkdiag2)
        return self.forward_guide_layer(input, out)

    def extra_repr(self) -> str:
        """Summarize factor shapes, bias presence and the guide flag."""
        return (
            f"blockdiag1={self.blkdiag1.shape}, "
            f"blockdiag2={self.blkdiag2.shape}, "
            f"bias={self.bias is not None}, "
            f"guide={self.training_config.enabled}"
        )
